In [1]:
# Notebook magics: auto-reload edited project modules and render figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
In [2]:
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *
from pathlib import Path
from functools import partial
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import re
import random

Preprocessing: tile the raw labelled images into fixed-size crops

In [3]:
# Provenance of the raw data (cluster path of the labelled 2019-06-10 export):
#/hpf/largeprojects/MICe/mdagys/Cnp-GFP_Study/2019-06-10_labelled/raw
raw_dir = Path("raw")
# NOTE: .ls() is fastai's Path extension (lists child paths), not stdlib pathlib.
raws = raw_dir.ls()
# Raw files come in matched pairs named *_image* / *_label*; sorting keeps the
# two lists aligned so they can be zipped below.
images = sorted([raw_path for raw_path in raws if "_image" in raw_path.name])
labels = sorted([raw_path for raw_path in raws if "_label" in raw_path.name])
# D-R_Z were the initial ones to be labelled, kinda more sloppy.
# images = sorted([raw_path for raw_path in raws if "_image" in raw_path.name and "D-R_Z" not in raw_path.name])
# labels = sorted([raw_path for raw_path in raws if "_label" in raw_path.name and "D-R_Z" not in raw_path.name])

processed_dir = Path("processed")
# Side length (pixels) of the square tiles cut from each raw image.
l=224
In [ ]:
# Tile each (image, label) pair into l-by-l crops and write them to processed_dir.
# Tiles containing at least one labelled pixel are always kept ("populated");
# empty tiles are skipped with probability `cutoff`, so cutoff=1 drops every
# empty tile. `popu`/`empty` count how many tiles of each kind were seen.
random.seed(23)  # make the empty-tile subsampling reproducible
empty = 0
popu = 0
cutoff = 1  # probability of skipping an empty tile (1.0 = skip all)

for image_path, label_path in zip(images, labels):
    # BUG FIX: the original passed cv.COLOR_BGR2GRAY (a cvtColor conversion
    # code) as the imread flag; the correct single-channel read flag is
    # cv.IMREAD_GRAYSCALE.
    image = cv.imread(image_path.as_posix(), cv.IMREAD_GRAYSCALE)
    label = cv.imread(label_path.as_posix(), cv.IMREAD_GRAYSCALE)

    # The image and its label mask must align pixel-for-pixel.
    if image.shape != label.shape:
        raise ValueError(image_path.as_posix() + label_path.as_posix())
    # Number of whole tiles that fit along each axis (edge remainders dropped).
    i_max = image.shape[0]//l
    j_max = image.shape[1]//l

    # If the cells were labelled as 255, or something else mistakenly, instead of 1.
    label[label != 0] = 1

    for i in range(i_max):
        for j in range(j_max):
            cropped_image = image[l*i:l*(i+1), l*j:l*(j+1)]
            cropped_label = label[l*i:l*(i+1), l*j:l*(j+1)]

            if (cropped_label != 0).any():
                popu += 1
                cropped_image_path = processed_dir/(image_path.stem + "_i" + str(i) + "_j" + str(j) + image_path.suffix)
                cropped_label_path = processed_dir/(label_path.stem + "_i" + str(i) + "_j" + str(j) + label_path.suffix)
            else:
                empty += 1
                # Subsample empty tiles: skip this one with probability `cutoff`.
                if random.random() < cutoff:
                    continue
                cropped_image_path = processed_dir/(image_path.stem + "_i" + str(i) + "_j" + str(j) + "_empty" + image_path.suffix)
                cropped_label_path = processed_dir/(label_path.stem + "_i" + str(i) + "_j" + str(j) + "_empty" + label_path.suffix)

            cv.imwrite(cropped_image_path.as_posix(), cropped_image)
            cv.imwrite(cropped_label_path.as_posix(), cropped_label)
In [ ]:
# Report how many populated vs. empty tiles were encountered during tiling.
print(popu)
print(empty)

Train the U-Net segmentation network

In [4]:
torch.cuda.set_device(1)
In [5]:
# Segmentation classes: mask value 0 -> NOT-CELL, 1 -> CELL (see label[label!=0]=1 above).
codes = ["NOT-CELL", "CELL"]
bs = 4  # batch size; sized to fit GPU memory per the measurements below
#bs=16 and l=224 will use ~7300MiB for resnet34  before unfreezing
#bs=4 and l=224 use ~11500MiB for resnet50 before unfreezing
In [6]:
# Data-augmentation pipeline. Microscopy tiles have no canonical orientation,
# so flips in both axes and rotation are safe; lighting and warp are disabled
# (max_lighting/max_warp = None).
transforms = get_transforms(
    do_flip = True,
    flip_vert = True,
    max_zoom = 1, #consider
    max_rotate = 45,
    max_lighting = None,
    max_warp = None,
    p_affine = 0.75,
    p_lighting = 0.75)
In [7]:
get_label_from_image = lambda path: re.sub(r'_image_', '_label_', path.as_posix())

# Build the fastai data pipeline:
#  - keep only image tiles (exclude label files and "_empty" tiles)
#  - random 80/20 train/valid split with a fixed seed for reproducibility
#  - pair each image with its mask via get_label_from_image
src = (
    SegmentationItemList.from_folder(processed_dir)
    .filter_by_func(lambda fname:
                    'image' in Path(fname).name and "empty" not in Path(fname).name)
    .split_by_rand_pct(valid_pct=0.20, seed=1)
    .label_from_func(get_label_from_image, classes=codes)
)
# Apply the augmentations to both image and mask (tfm_y=True); normalize with
# ImageNet stats to match the (presumably pretrained) encoder — confirm that
# unet_learner below is using pretrained weights.
data = (
    src.transform(transforms, tfm_y=True)
    .databunch(bs=bs)
    .normalize(imagenet_stats)
)
In [ ]:
data.show_batch(2, figsize=(10,7))
In [8]:
# models.resnet34
model_path = Path("../../models")
# U-Net with a ResNet-50 encoder; metric is fastai's dice with iou=True
# (i.e. Jaccard/IoU rather than plain dice).
learn = unet_learner(data, models.resnet50, metrics=partial(dice, iou=True))
# Explicit per-class loss weights; [1, 1] weights both classes equally, so this
# is effectively unweighted — kept as a hook for rebalancing the rare CELL class.
learn.loss_func = CrossEntropyFlat(axis=1, weight = torch.Tensor([1,1]).cuda())
In [9]:
# Learning-rate range test with the encoder still frozen.
lr_find(learn)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [10]:
lr = 5e-5  # chosen from the LR-finder plot above
# Stage 1: train the decoder only (encoder frozen), one-cycle schedule.
learn.fit_one_cycle(25, lr)
epoch train_loss valid_loss dice time
0 0.023488 0.019877 0.209598 03:38
1 0.018001 0.025803 0.035352 03:36
2 0.018150 0.015012 0.328353 03:37
3 0.016106 0.016819 0.186077 03:35
4 0.017777 0.014573 0.381587 03:35
5 0.017257 0.015137 0.395617 03:37
6 0.014920 0.014138 0.373547 03:39
7 0.013938 0.012927 0.348498 03:37
8 0.013879 0.013520 0.305893 03:37
9 0.015615 0.012981 0.394207 03:36
10 0.014143 0.012430 0.399886 03:37
11 0.012947 0.012572 0.413324 03:36
12 0.013041 0.014963 0.293644 03:39
13 0.012636 0.011983 0.400940 03:40
14 0.013891 0.012266 0.397941 03:38
15 0.012137 0.011898 0.423105 03:37
16 0.012745 0.011981 0.406798 03:40
17 0.014052 0.011582 0.427009 03:36
18 0.012233 0.011777 0.409914 03:33
19 0.012691 0.011993 0.436071 03:36
20 0.011273 0.011689 0.437079 03:36
21 0.011301 0.011569 0.419610 03:37
22 0.011809 0.011517 0.440828 03:36
23 0.011088 0.011925 0.423847 03:36
24 0.012269 0.011598 0.423712 03:35
In [11]:
learn.save(model_path/"2019-07-02_RESNET50_IOU0.42_1stage")
In [ ]:
learn.load(model_path/"2019-07-02_RESNET50_IOU0.41_1stage");
In [12]:
learn.unfreeze()
In [13]:
# Re-run the LR range test now that the encoder is unfrozen.
lr_find(learn)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [14]:
# Discriminative learning rates: earlier (pretrained) layers get lr/1000,
# later layers up to lr/10.
lrs = slice(lr/1000,lr/10)
learn.fit_one_cycle(15, lrs)
epoch train_loss valid_loss dice time
0 0.012620 0.012545 0.390357 03:41
1 0.012682 0.012582 0.378745 03:39
2 0.013211 0.012378 0.413408 03:41
3 0.013809 0.013032 0.343238 03:40
4 0.013848 0.012497 0.418637 03:41
5 0.013817 0.013194 0.402935 03:39
6 0.012792 0.012365 0.408522 03:39
7 0.012762 0.012415 0.396147 03:40
8 0.012971 0.012935 0.355501 03:41
9 0.013049 0.012801 0.373378 03:41
10 0.014654 0.012237 0.395947 03:39
11 0.012884 0.012289 0.407171 03:38
12 0.013566 0.012326 0.394652 03:37
13 0.012465 0.012253 0.398518 03:39
14 0.012282 0.012242 0.399497 03:47
In [ ]:
learn.save(models_path/"2019-06-14_RESNET34_IOU0.25_2stage")
In [ ]:
learn.export(file = models_path/"2019-06-14_RESNET34_IOU0.25_2stage.pkl")

Check predictions on the validation set

In [ ]:
# Inspect the validation set: its size and the structure of one item.
# (Idiom fix: use len() instead of calling .__len__() directly.)
print(len(learn.data.valid_ds))  # number of validation items
print(learn.data.valid_ds[0])    # tuple of input image and segmentation mask
print(learn.data.valid_ds[0][1]) # just the mask
In [15]:
# preds = learn.get_preds(with_loss=True)
# Run inference over the validation set: returns (probabilities, targets).
preds = learn.get_preds()
In [ ]:
# Explore the structure of the predictions tuple.
print(len(preds)) # tuple of list of probs and targets
print(preds[0].shape) #predictions
print(preds[0][0].shape) #probabilities for each label
print(learn.data.classes) #what is each label
print(preds[0][0][0].shape) #probabilities for label 0
# for i in range(0,N):
#     print(torch.max(preds[0][i][1]))

# Image(preds[1][0]).show()
In [16]:
# Collect validation inputs, targets, and predictions into parallel lists.
# (Idiom fixes: len() over .__len__(), guard clause with a message instead of
# a bare ValueError inside an if/else.)
N = len(learn.data.valid_ds)
if N != preds[1].shape[0]:
    raise ValueError("prediction count does not match validation-set size")

xs = [learn.data.valid_ds[i][0] for i in range(N)]   # input images
ys = [learn.data.valid_ds[i][1] for i in range(N)]   # ground-truth masks
p0s = [Image(preds[0][i][0]) for i in range(N)]      # class-0 (NOT-CELL) probability maps
p1s = [Image(preds[0][i][1]) for i in range(N)]      # class-1 (CELL) probability maps
argmax = [Image(preds[0][i].argmax(dim=0)) for i in range(N)]  # hard predictions
In [ ]:
# Confirm tensor shapes line up across inputs, targets, and probability maps.
print(xs[0].px.shape)
print(ys[0].px.shape)
print(p0s[0].px.shape)
print(p1s[0].px.shape)
In [17]:
# Overlay hard predictions (blue) on ground-truth masks (orange) for every
# validation tile.
ncol = 3
nrow = (N + ncol - 1)//ncol  # ceil(N / ncol): enough rows for all N tiles
fig=plt.figure(figsize=(12, nrow*5))

# BUG FIX: the original looped `for i in range(1, N)` and so never plotted the
# last validation item; iterate over all N items (subplot indices are 1-based).
for i in range(N):
    fig.add_subplot(nrow, ncol, i + 1)
#     plt.imshow(xs[i].px.permute(1, 2, 0), cmap = "Oranges", alpha=0.5)
    plt.imshow(argmax[i].px, cmap = "Blues", alpha=0.5)
#     plt.imshow(p1s[i].px, cmap = "Blues", alpha=0.5)
    plt.imshow(ys[i].px[0], cmap = "Oranges", alpha=0.5)
plt.show()
In [18]:
# Overlay hard predictions (blue) on the raw input tiles (grey).
fig=plt.figure(figsize=(12, nrow*5))

# BUG FIX: the original looped `for i in range(1, N)` and so never plotted the
# last validation item; iterate over all N items (subplot indices are 1-based).
for i in range(N):
    fig.add_subplot(nrow, ncol, i + 1)
    plt.imshow(xs[i].px.permute(1, 2, 0), cmap = "Greys", alpha=1)
    plt.imshow(argmax[i].px, cmap = "Blues", alpha=0.5)
#     plt.imshow(p1s[i].px, cmap = "Blues", alpha=0.5)
#     plt.imshow(ys[i].px[0], cmap = "Oranges", alpha=0.5)
plt.show()
In [19]:
learn.show_results(rows=16, ds_type=DatasetType.Train)
In [20]:
learn.show_results(rows=16)